
import torch
from modules import *
import torch.nn.init as torch_init


def weight_init(m):
    """Xavier-uniform initialise the weight of conv / linear style modules.

    Meant to be passed to ``nn.Module.apply``. A module is selected when its
    class name contains ``'Conv'`` or ``'Linear'``. Since that is a name match,
    it can also hit modules that own no usable weight tensor (for instance a
    container class whose name merely contains "Linear"); such modules are
    skipped instead of raising. Biases are not modified.

    Args:
        m: The module currently visited by ``apply``.
    """
    classname = m.__class__.__name__
    if 'Conv' not in classname and 'Linear' not in classname:
        return

    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor):
        # Name matched, but there is nothing to initialise on this module.
        return
    if weight.dim() < 2:
        # xavier_uniform_ needs fan-in and fan-out, i.e. at least 2 dimensions.
        return

    torch_init.xavier_uniform_(weight)


class XModel(nn.Module):
    """Encoder followed by a left-padded 1-D convolution scoring head.

    The encoder output is scored by one shared ``Conv1d`` (single output
    channel, kernel size ``cfg.t_step``) applied at strides 1, 2, 3 and 4, each
    followed by a sigmoid.

    Args:
        cfg: Configuration object. The attributes read here are ``t_step``,
            ``feat_dim``, ``hid_dim``, ``out_dim``, ``head_num``, ``win_size``,
            ``dropout``, ``gamma``, ``bias``, ``norm`` and ``temp``.
    """

    def __init__(self, cfg):
        super(XModel, self).__init__()
        # Kernel size of the classifier; also determines the left padding.
        self.t = cfg.t_step
        self.self_attention = XEncoder(
            d_model=cfg.feat_dim,
            hid_dim=cfg.hid_dim,
            out_dim=cfg.out_dim,
            n_heads=cfg.head_num,
            win_size=cfg.win_size,
            dropout=cfg.dropout,
            gamma=cfg.gamma,
            bias=cfg.bias,
            norm=cfg.norm,
        )
        self.classifier = nn.Conv1d(cfg.out_dim, 1, self.t, padding=0)
        # Learnable scalar initialised to log(1 / temp). It is not used by the
        # methods of this class; presumably consumed by the training code.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / cfg.temp))
        self.apply(weight_init)

    def forward(self, x, seq_len, scale):
        """Encode ``x`` and score the encoding at several strides.

        Args:
            x, seq_len, scale: Forwarded unchanged to the encoder.

        Returns:
            A tuple ``(logits, x_v, (logits_1, logits_2, logits_3))`` where
            ``logits`` is the stride-1 score, ``x_v`` is the encoder's second
            output, and ``logits_1..3`` are the scores at strides 2, 3 and 4.
        """
        x_e, x_v = self.self_attention(x, seq_len, scale)
        # assumes x_e is (batch, out_dim, length) as required by the Conv1d
        # classifier -- TODO confirm against XEncoder.
        # Left-pad the last axis by t - 1 so that every output position only
        # sees that position and the t - 1 positions before it, and so that the
        # stride-1 output keeps the input length.
        pad_logits = F.pad(x_e, (self.t - 1, 0))

        logits = self.conv_with_stride(pad_logits, 1)
        logits_1 = self.conv_with_stride(pad_logits, 2)
        logits_2 = self.conv_with_stride(pad_logits, 3)
        logits_3 = self.conv_with_stride(pad_logits, 4)

        return logits, x_v, (logits_1, logits_2, logits_3)

    def conv_with_stride(self, x, stride):
        """Apply the classifier kernel to ``x`` at the given stride.

        The convolution is evaluated functionally with the classifier's weight
        and bias, so the ``stride`` stored on ``self.classifier`` is never
        modified by this call.

        Args:
            x: Input accepted by ``F.conv1d`` with ``cfg.out_dim`` channels.
            stride: Convolution stride to use for this call.

        Returns:
            Sigmoid of the convolution output with its last two axes swapped
            (``permute(0, 2, 1)``).
        """
        conv = self.classifier
        logits = F.conv1d(
            x,
            conv.weight,
            conv.bias,
            stride=stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
        )
        logits = logits.permute(0, 2, 1)
        return torch.sigmoid(logits)
    
class LinearPredictor(nn.Module):
    """Four 1-D convolutions, global max pooling and a linear layer with sigmoid output."""

    def __init__(self, input_dim, num_filters, output_dim):
        """
        input_dim: dimensionality of each input vector (in-channels of conv1)
        num_filters: number of filters of the first convolution (the second uses twice as many)
        output_dim: output dimension; 1 here
        """
        super(LinearPredictor, self).__init__()
        # 1-D convolutions with kernel size 3 and padding 1, so the sequence
        # length is preserved; channels go input_dim -> num_filters ->
        # num_filters * 2 -> 512 -> 128.
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=num_filters, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters*2, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=num_filters*2, out_channels=512, kernel_size=3, padding=1)
        self.conv4 = nn.Conv1d(in_channels=512, out_channels=128, kernel_size=3, padding=1)
        self.drop = nn.Dropout(0.1)
        # Linear layer mapping the pooled 128 features to the output.
        self.fc = nn.Linear(128, output_dim)

    def forward(self, x):
        """
        x: input data of shape [batch_size, seq_len, input_dim]

        Returns the sigmoid of the linear output. A trailing axis of size 1
        (output_dim == 1) is removed, giving shape [batch_size]; the batch axis
        is always kept, including when batch_size is 1.
        """
        # Reorder to the layout expected by Conv1d: [batch_size, input_dim, seq_len]
        x = x.permute(0, 2, 1)
        # Four convolution layers; dropout follows the first and third.
        x = self.drop(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = self.drop(F.relu(self.conv3(x)))
        x = F.relu(self.conv4(x))
        # Global max pooling: collapse the sequence axis to length 1.
        # NOTE(review): the original comment said "average pooling" but the code
        # has always used max pooling -- confirm which one is intended.
        x = F.max_pool1d(x, x.size(2))
        # Flatten the pooled output for the linear layer.
        x = x.view(x.size(0), -1)
        # Linear layer produces the final prediction.
        x = self.fc(x)
        # Squeeze only the last axis: a bare squeeze() would also remove the
        # batch axis for a single-sample batch and return a 0-d tensor.
        x = torch.sigmoid(x.squeeze(-1))
        return x